package dbconn

import (
	"context"
	"errors"
	"fmt"
	"gitee.com/lailonghui/vehicle-supervision-framework/pkg/asserts"
	"gitee.com/lailonghui/vehicle-supervision-framework/pkg/ferror"
	"gitee.com/lailonghui/vehicle-supervision-framework/pkg/loggers"
	"gitee.com/lailonghui/vehicle-supervision-framework/pkg/passwords"
	"gitee.com/shiqiyue/xd-bi/internal/modules/cr/enums"
	"gitee.com/shiqiyue/xd-bi/internal/modules/cr/model"
	"gitee.com/shiqiyue/xd-bi/internal/pkg/instance"
	"go.uber.org/zap"
	"gorm.io/driver/postgres"
	"gorm.io/gorm"
	"gorm.io/gorm/logger"
	gormopentracing "gorm.io/plugin/opentracing"
	"log"
	"os"
	"sync"
	"time"
)

var (
	// connCache maps a DSN to the connection pool opened for it.
	// createNewConn writes to it while holding lock.
	connCache = map[string]*cacheConn{}
	// lock serializes the creation of new pools in createNewConn.
	lock sync.Mutex
)

// cacheConn is a cached gorm handle together with the ReportDb.UpdatedAt
// value it was created for; getConnByDsn treats a mismatch as stale.
type cacheConn struct {
	db        *gorm.DB  // pool handle handed back to callers
	updatedAt time.Time // reportDb.UpdatedAt at the time the pool was opened
}

// GetConn returns a *gorm.DB bound to ctx for the database described by
// reportDb. It reuses a cached pool when one exists for the same DSN and the
// same reportDb.UpdatedAt; otherwise it opens (and caches) a new pool.
//
// Only enums.REPORT_DB_TYPE_POSTGRES is supported; any other type yields an
// error. When instance.GetConfig().Dev.ForceStageDb is set, the DSN built from
// reportDb is replaced by a fixed staging DSN.
func GetConn(ctx context.Context, reportDb model.ReportDb) (*gorm.DB, error) {
	if reportDb.Type != enums.REPORT_DB_TYPE_POSTGRES {
		return nil, errors.New("不支持的数据库类型")
	}

	dsn, err := getPostgresDsn(reportDb)
	if err != nil {
		return nil, err
	}
	forceStage := instance.GetConfig().Dev.ForceStageDb
	if forceStage {
		// NOTE(review): these staging credentials are committed to source
		// control; move this DSN into configuration or a secret store.
		dsn = "host=120.37.177.122 user=postgres password=cczo6ku$N&CiDEURxyvDX38Ij@eHz6 dbname=vehicle port=18001 sslmode=disable TimeZone=Asia/Shanghai"
	}

	// Fast path. connCache is written by createNewConn while holding lock, so
	// the read must hold it too: an unsynchronized concurrent map read/write
	// is a fatal runtime error in Go.
	lock.Lock()
	db := getConnByDsn(dsn, reportDb)
	lock.Unlock()
	if db != nil {
		return db.WithContext(ctx), nil
	}

	// The DSN embeds the plaintext password, so log identifying fields only.
	loggers.Info("创建新的数据库链接", ctx,
		zap.String("host", reportDb.Host),
		zap.Any("port", reportDb.Port),
		zap.String("dbname", reportDb.DbName),
		zap.Bool("forceStageDb", forceStage))
	db, err = createNewConn(ctx, dsn, reportDb)
	if err != nil {
		return nil, err
	}
	return db.WithContext(ctx), nil
}

// createNewConn opens a new gorm/Postgres connection pool for dsn, applies the
// pool limits configured on reportDb (non-positive values keep the
// database/sql defaults), registers the OpenTracing plugin and stores the
// pool in connCache under dsn together with reportDb.UpdatedAt.
//
// The whole operation runs while holding lock. If a matching pool (same dsn
// and UpdatedAt) was cached by another goroutine in the meantime, that pool is
// returned instead of opening a second one.
func createNewConn(ctx context.Context, dsn string, reportDb model.ReportDb) (*gorm.DB, error) {
	lock.Lock()
	defer lock.Unlock()
	if d := getConnByDsn(dsn, reportDb); d != nil {
		return d, nil
	}

	// The DSN embeds the plaintext password, so log identifying fields only.
	loggers.Info("准备链接数据库", ctx,
		zap.String("host", reportDb.Host),
		zap.Any("port", reportDb.Port),
		zap.String("dbname", reportDb.DbName))

	newLogger := logger.New(
		log.New(os.Stdout, "\r\n", log.LstdFlags), // io writer
		logger.Config{
			SlowThreshold: time.Second, // statements slower than this are reported as slow SQL
			LogLevel:      logger.Info, // log every statement
			Colorful:      true,        // colored output enabled
		},
	)
	db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{QueryFields: true, Logger: newLogger})
	if err != nil {
		return nil, ferror.Wrap("连接PG失败", err)
	}

	// Configure the underlying database/sql pool.
	sqlDB, err := db.DB()
	asserts.Nil(err, err)
	if reportDb.MaxIdleConns > 0 {
		// Maximum number of connections kept in the idle pool.
		sqlDB.SetMaxIdleConns(reportDb.MaxIdleConns)
	}
	if reportDb.MaxOpenConns > 0 {
		// Maximum number of open connections to the database.
		sqlDB.SetMaxOpenConns(reportDb.MaxOpenConns)
	}
	if reportDb.ConnMaxLifetime > 0 {
		// ConnMaxLifetime is interpreted as a number of seconds.
		sqlDB.SetConnMaxLifetime(time.Duration(reportDb.ConnMaxLifetime) * time.Second)
	}

	// opentracing
	err = db.Use(gormopentracing.New())
	if err != nil {
		// Release the pool opened above; nothing else references it, so it
		// would otherwise leak its connections.
		_ = sqlDB.Close()
		return nil, err
	}

	// An existing entry for this DSN was created for a different UpdatedAt and
	// is about to be replaced. Callers may still hold it, so it is not closed;
	// instead its idle connections are released and connections are closed as
	// they are returned, so the orphaned pool stops holding server resources.
	if stale := connCache[dsn]; stale != nil && stale.db != nil {
		if staleSQLDB, staleErr := stale.db.DB(); staleErr == nil {
			staleSQLDB.SetMaxIdleConns(0)
		}
	}
	connCache[dsn] = &cacheConn{
		db:        db,
		updatedAt: reportDb.UpdatedAt,
	}
	return db, nil
}

// getConnByDsn returns the pool cached under dsn, but only when it was created
// for the same reportDb.UpdatedAt; a missing or stale entry yields nil.
//
// It reads connCache without taking lock itself (createNewConn calls it while
// already holding lock), so any other concurrent use must be guarded by lock.
func getConnByDsn(dsn string, reportDb model.ReportDb) *gorm.DB {
	if entry := connCache[dsn]; entry != nil && entry.updatedAt.Equal(reportDb.UpdatedAt) {
		return entry.db
	}
	return nil
}

// getPostgresDsn builds a keyword/value Postgres DSN for db.
//
// The stored password is decrypted with passwords.AesDeCryptBase64 using the
// key from instance.GetConfig().App.DbPasswordPrivateKey; a decryption failure
// is returned as an error. host, user, password and dbname are quoted when
// needed (see quotePgDsnValue). sslmode is always "disable" and TimeZone is
// always "Asia/Shanghai". TimeZone is left unquoted because the gorm postgres
// driver extracts it from the raw DSN text.
func getPostgresDsn(db model.ReportDb) (string, error) {
	passwordBs, err := passwords.AesDeCryptBase64(db.Password, []byte(instance.GetConfig().App.DbPasswordPrivateKey))
	if err != nil {
		// Previously this returned ("", nil), hiding the cause and making the
		// caller try to connect with an empty DSN.
		return "", ferror.Wrap("解密数据库密码失败", err)
	}
	return fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%d sslmode=%s TimeZone=%s",
		quotePgDsnValue(db.Host),
		quotePgDsnValue(db.Username),
		quotePgDsnValue(string(passwordBs)),
		quotePgDsnValue(db.DbName),
		db.Port,
		"disable",
		"Asia/Shanghai"), nil
}

// quotePgDsnValue makes v safe to embed in a keyword/value connection string.
// Values that are empty or contain whitespace, a single quote or a backslash
// are wrapped in single quotes with ' and \ backslash-escaped; any other value
// is returned unchanged, so ordinary DSNs keep their previous form.
func quotePgDsnValue(v string) string {
	needsQuoting := v == ""
	for i := 0; i < len(v) && !needsQuoting; i++ {
		switch v[i] {
		case ' ', '\t', '\n', '\r', '\v', '\f', '\'', '\\':
			needsQuoting = true
		}
	}
	if !needsQuoting {
		return v
	}
	buf := make([]byte, 0, len(v)+4)
	buf = append(buf, '\'')
	for i := 0; i < len(v); i++ {
		if v[i] == '\'' || v[i] == '\\' {
			buf = append(buf, '\\')
		}
		buf = append(buf, v[i])
	}
	buf = append(buf, '\'')
	return string(buf)
}
